import pdb

import numpy as np
from scipy.stats import bernoulli, binom
from scipy.sparse import csr_matrix
import random
import numba

# NOTE(review): this seeds only Python's stdlib `random` module. The samplers
# in this file draw from NumPy's global RNG instead (np.random.* directly, and
# scipy.stats `.rvs` calls made without a `random_state`, which fall back to
# NumPy's global RNG), so this call does not make the generated graphs or
# features reproducible. `np.random.seed(...)` would be needed for that.
random.seed(38138)


# Functions for data generation

def degree_matrix(A):
    """Return the dense diagonal degree matrix of the adjacency matrix ``A``.

    Diagonal entry ``[i, i]`` is the number of nonzero entries in row ``i``
    of ``A``, as reported by ``A.nonzero()``. Edge weights are not summed;
    every nonzero entry counts as 1.

    Args:
        A: adjacency matrix exposing ``.shape`` and ``.nonzero()`` (a numpy
            array or a scipy sparse matrix).

    Returns:
        A dense float64 array of shape ``(A.shape[0], A.shape[0])``. Note that
        this is O(n^2) memory even when ``A`` is sparse.
    """
    num_nodes = A.shape[0]
    nonzero_rows = A.nonzero()[0]
    # One count per node, zero-padded so isolated trailing nodes still appear.
    row_counts = np.bincount(nonzero_rows, minlength=num_nodes).astype(float)
    return np.diag(row_counts)


def generate_data(Node_labels, X1):
    """Scale ``X1`` element-wise by ``Node_labels`` (numpy broadcasting).

    Within this file it receives a per-node projection vector and the node
    labels, producing label-signed scalar features.
    """
    signed_values = X1 * Node_labels
    return signed_values


def generate_conv_data(Node_labels, X_tilde):
    """Multiply row ``i`` of the feature matrix ``X_tilde`` by ``Node_labels[i]``.

    ``Node_labels`` is reshaped into a column so it broadcasts across the
    feature axis (assumes one label per row of ``X_tilde``).
    """
    label_column = Node_labels.reshape((-1, 1))
    return X_tilde * label_column


# Define CSBM model

@numba.njit(cache=True)
def adj(label):  # Function is compiled and runs in machine code
    """Return same-label flags for every node pair ``(i, j)`` with ``i < j``.

    Pairs are visited row by row (``i`` ascending, then ``j`` ascending), so
    the result is a list of ``len(label) * (len(label) - 1) // 2`` booleans,
    True where ``label[i] == label[j]``.
    """
    n = len(label)
    ind = []
    # Plain `range` on purpose: the body appends to a shared list, so the loop
    # must stay serial. `numba.prange` without `parallel=True` already ran
    # serially; with parallel enabled it would race on `ind.append`.
    for i in range(n):
        for j in range(i + 1, n):
            ind.append(label[i] == label[j])
    return ind


@numba.njit(cache=True)
def adj2(label, ind):  # Function is compiled and runs in machine code
    """Mark same-label node pairs in a preallocated flat upper-triangle buffer.

    For every pair ``(i, j)`` with ``i < j`` and ``label[i] == label[j]``,
    sets ``ind[k] = True``. Here ``k`` is the position of ``(i, j)`` when the
    strict upper triangle is enumerated row by row. Pairs with differing
    labels are never written, so ``ind`` must be all-False on entry (e.g.
    ``np.zeros(n * (n - 1) // 2, dtype=bool)``). Mutates ``ind`` in place and
    returns None.
    """
    n = len(label)
    # NOTE: `njit` is not given `parallel=True`, so `prange` runs serially
    # here. Each (i, j) writes a distinct slot, so parallelizing would be safe.
    for i in numba.prange(n):
        for j in range(i + 1, n):
            if label[i] == label[j]:
                # Rows r < i contribute sum_{r<i} (n - 1 - r) = n*i - i*(i+1)/2
                # earlier entries; (j - i - 1) is column j's offset within row i.
                ind[(n * i - ((i + 1) * i) // 2) + (j - i - 1)] = True


# https://networkx.org/documentation/stable/reference/generated/networkx.generators.community.stochastic_block_model.html
def CSBM(n, epsilon, p, q):
    """Sample a two-community stochastic block model graph with +/-1 labels.

    Each node is labelled +1 with probability ``epsilon`` and -1 otherwise.
    Every unordered pair ``(i, j)``, ``i < j``, is then connected
    independently with probability ``p`` when the two labels agree and
    ``q`` when they differ.

    Randomness comes from scipy.stats ``.rvs`` calls without a
    ``random_state``. Draw order: node labels, then same-label pairs, then
    different-label pairs.

    Args:
        n: number of nodes.
        epsilon: probability that a node is labelled +1.
        p: edge probability between nodes with the same label.
        q: edge probability between nodes with different labels.

    Returns:
        ``(A, Node_labels)``. ``A`` is a symmetric ``(n, n)`` ``csr_matrix``
        holding 1.0 for every sampled edge and nothing on the diagonal.
        ``Node_labels`` is a float array of +/-1 values of length ``n``.

    Side effect: prints the number of stored nonzeros of ``A``. Memory is
    O(n^2) because all n*(n-1)/2 candidate pairs are materialised.
    """
    # Bernoulli draws are in {0, 1}; shifting by 0.5 and taking the sign maps
    # them to {-1.0, +1.0}.
    Node_labels = np.sign(bernoulli(epsilon).rvs(n) - 0.5)

    # Every candidate pair (i < j) of the strict upper triangle.
    upper_rows, upper_cols = np.triu_indices(n, 1)
    same_label = Node_labels[upper_rows] == Node_labels[upper_cols]
    diff_label = ~same_label

    # Partition the candidate pairs by whether their endpoints share a label.
    same_rows, same_cols = upper_rows[same_label], upper_cols[same_label]
    diff_rows, diff_cols = upper_rows[diff_label], upper_cols[diff_label]

    keep_same = bernoulli(p).rvs(same_rows.size).astype(bool)
    keep_diff = bernoulli(q).rvs(diff_rows.size).astype(bool)

    src = np.concatenate([same_rows[keep_same], diff_rows[keep_diff]])
    dst = np.concatenate([same_cols[keep_same], diff_cols[keep_diff]])

    # Store each sampled edge in both directions so A is symmetric.
    weights = np.ones(2 * src.size)
    A = csr_matrix(
        (weights, (np.concatenate([src, dst]), np.concatenate([dst, src]))),
        shape=(n, n),
    )
    print(f'#NNZ {A.nnz}')
    return A, Node_labels


def high_dim_Gaussian(Node_labels, d, mu):
    """Sample label-signed Gaussian node features and their projection on mu.

    One row per entry of ``Node_labels`` is drawn from ``N(mu, I_d / d)``
    via ``np.random.multivariate_normal`` with ``check_valid='raise'``.

    Args:
        Node_labels: per-node labels (presumably the +/-1 labels from ``CSBM``).
        d: feature dimension.
        mu: mean vector of length ``d``; also the projection direction.

    Returns:
        ``(X, X_tilde)``. ``X[i]`` is ``Node_labels[i] * (row_i @ mu)``.
        ``X_tilde`` is the raw sample matrix with row ``i`` multiplied by
        ``Node_labels[i]``.
    """
    covariance = np.identity(d) / d
    samples = np.random.multivariate_normal(mu, covariance, len(Node_labels), 'raise')
    X = generate_data(Node_labels, samples @ mu)
    return X, generate_conv_data(Node_labels, samples)

def high_dim_Laplace(Node_labels, d, mu):
    """Sample label-signed Laplace node features and their projection on mu.

    Draws a ``(len(Node_labels), d)`` matrix of independent Laplace entries
    with ``loc=mu`` and ``scale=1.0``.

    Args:
        Node_labels: per-node labels (presumably the +/-1 labels from ``CSBM``).
        d: feature dimension.
        mu: Laplace location. It is passed as ``loc`` and also multiplied by
            ``np.ones(d)``, so a scalar or a length-``d`` vector both broadcast.

    Returns:
        ``(X, X_tilde_label)``. ``X_tilde_label`` is the raw sample matrix
        with row ``i`` multiplied by ``Node_labels[i]``. ``X`` is
        ``Node_labels`` times the projection of the raw samples onto
        ``mu * np.ones(d)``.
    """
    # NOTE(review): entries use scale=1.0 (variance 2 per coordinate), whereas
    # high_dim_Gaussian uses covariance I/d. Confirm the differing noise level
    # is intended.
    X_tilde = np.random.laplace(loc=mu, scale=1.0, size=(len(Node_labels), d))
    X_tilde_label = generate_conv_data(Node_labels, X_tilde)
    # Broadcast mu to a length-d direction so a scalar mu also works.
    X_temp = X_tilde @ (mu * np.ones(d))
    X = generate_data(Node_labels, X_temp)
    return X, X_tilde_label

def high_dim_Gaussian_robustness(Node_labels, d, mu, delta_mu, gamma):
    """Sample Gaussian node features perturbed by an extra Gaussian noise term.

    Per node, draws a signal row from ``N(mu, I_d / d)``, then a noise row
    from ``N(delta_mu, gamma * I_d / d)``, then a random sign
    ``s_i`` in {-1, +1} with probability 1/2 each (in that RNG order).

    Args:
        Node_labels: per-node labels (presumably the +/-1 labels from ``CSBM``).
        d: feature dimension.
        mu: signal mean of length ``d``; also the projection direction.
        delta_mu: noise mean of length ``d``.
        gamma: multiplier on the noise covariance ``I_d / d``.

    Returns:
        ``(X, X_tilde)``. ``X[i]`` is
        ``Node_labels[i] * (signal_i @ mu + s_i * (noise_i @ mu))``.
        ``X_tilde`` is ``signal + noise`` with row ``i`` multiplied by
        ``Node_labels[i]``.
    """
    num_nodes = len(Node_labels)
    base_cov = np.identity(d) / d
    signal = np.random.multivariate_normal(mu, base_cov, num_nodes, 'raise')
    noise = np.random.multivariate_normal(delta_mu, gamma * base_cov, num_nodes, 'raise')
    noise_sign = np.sign(bernoulli(1/2).rvs(num_nodes) - 0.5)

    # NOTE(review): the random sign s_i is applied to the noise in the scalar
    # feature X but NOT to the noise inside the returned X_tilde. So
    # X != X_tilde @ mu here, unlike high_dim_Gaussian / high_dim_Laplace.
    # Confirm which of the two is the intended robustness model.
    projected = signal @ mu + (noise @ mu) * noise_sign
    X = generate_data(Node_labels, projected)
    X_tilde = generate_conv_data(Node_labels, signal + noise)

    return X, X_tilde

# def one_layer_conv(n, X_tilde, A, mu, R=1, b=0):
#     I = np.identity(n)
#     A_tilde = A + I
#     D = degree_matrix(A_tilde)
#
#     D_inv = np.linalg.pinv(D)
#     mu_2 = np.mat(mu)
#     epsilon_error = 0.00000001
#
#     #Predict with linear transformation & sign function
#     X_conv1 = np.dot(D_inv, A_tilde)
#     X_conv2 = np.dot(X_conv1, X_tilde)
#     X_conv3 = np.dot(X_conv2, mu_2.T)
#     Y1 = np.sign(X_conv3 + epsilon_error)
#     Y1 = np.array(Y1).reshape(-1)
#     return X_conv1, X_conv2, X_conv3, Y1
